from typing import List

from data import *
from julia import Main
import math
from utils import *
from sklearn.linear_model import Ridge
import numpy.matlib as matlib


def full_obs_ite_multi_trials(X, y0, y1, params, nt):
    """Run the configured ITE estimator ``nt`` times and summarise its error.

    Every trial calls ``full_obs_ite_est(X, y0, y1, params)`` and compares the
    squeezed estimate with the true effect ``y1 - y0``.

    Args:
        X: Covariate array. Only ``X.shape[0]`` is read here; the array itself
            is forwarded to the estimator.
        y0: Control outcomes (assumed 1-D with one entry per row of ``X`` --
            TODO confirm against callers).
        y1: Treated outcomes, same shape as ``y0``.
        params: Estimator configuration, forwarded unchanged.
        nt: Number of trials to run; must be at least 1.

    Returns:
        Tuple ``(rmse_mean, rmse_std, rel_mean, rel_std)`` where

        * the per-trial RMSE is ``||estimate - (y1 - y0)|| / sqrt(X.shape[0])``,
        * ``rmse_mean`` / ``rmse_std`` are its mean and population standard
          deviation (divisor ``nt``) over the trials,
        * ``rel_mean`` / ``rel_std`` are the same statistics after dividing
          each RMSE by ``||y1 - y0||``.

    Raises:
        ValueError: If ``nt`` is not positive.
    """
    if nt <= 0:
        raise ValueError("nt must be a positive number of trials, got %r" % (nt,))

    n = X.shape[0]
    diff = y1 - y0
    tee = np.expand_dims(diff, axis=1)
    tee_norm = np.linalg.norm(diff)

    teehat = np.zeros((y0.shape[0], nt))
    for i in range(nt):
        teehat[:, i] = np.squeeze(full_obs_ite_est(X, y0, y1, params))

    # ``tee`` is a single column and broadcasts against all nt trial columns,
    # so no explicit tiled copy of it is needed.
    rmses = np.linalg.norm(teehat - tee, axis=0) / np.sqrt(n)
    # NOTE(review): the RMSE already carries a 1/sqrt(n) factor while
    # ``tee_norm`` does not -- confirm this normalisation is intended.
    rel_err = rmses / tee_norm

    rmse_mean = np.mean(rmses)
    rel_mean = rmse_mean / tee_norm
    rmse_std = np.sqrt(np.mean((rmses - rmse_mean) ** 2))
    rel_std = np.sqrt(np.mean((rel_err - rel_mean) ** 2))
    return rmse_mean, rmse_std, rel_mean, rel_std


def full_obs_ite_est(X, y0, y1, params):
    """Dispatch to the ITE estimator selected by ``params["method"]``.

    Args:
        X: Covariate array, forwarded to the estimator.
        y0: Control outcomes, forwarded to the estimator.
        y1: Treated outcomes, forwarded to the estimator.
        params: Configuration mapping. ``params["method"]`` must be one of
            ``'rand_vec_ite'``, ``'simple_lev_ite'``,
            ``'subsample_rand_vec_ite'``, ``'subsample_simple_lev_ite'`` or
            ``'baseline'``. The whole mapping is passed on to the estimator.

    Returns:
        Whatever the selected estimator returns.

    Raises:
        KeyError: If ``params`` has no ``"method"`` entry.
        ValueError: If the method name is not recognised. Failing here is
            clearer than returning ``None`` and breaking later in the caller.
    """
    method = params["method"]
    # The estimator names are only looked up inside the matching branch.
    if method == 'rand_vec_ite':
        estimator = rand_vec_ite
    elif method == 'simple_lev_ite':
        estimator = simple_lev_approach_ite
    elif method == 'subsample_rand_vec_ite':
        estimator = subsample_rand_vec_ite
    elif method == 'subsample_simple_lev_ite':
        estimator = subsample_simple_lev_approach_ite
    elif method == 'baseline':
        estimator = baseline
    else:
        raise ValueError(
            "Unknown ITE method %r; expected one of 'rand_vec_ite', "
            "'simple_lev_ite', 'subsample_rand_vec_ite', "
            "'subsample_simple_lev_ite', 'baseline'" % (method,)
        )
    return estimator(X, y0, y1, params)


def simple_lev_approach_ite(X, y0, y1, params):
    """Estimate per-row treatment effects from two group-wise ridge fits.

    Each row is independently labelled +1 or -1 with equal probability. Rows
    labelled +1 contribute their ``y1`` value and rows labelled -1 their
    ``y0`` value. A ridge model without intercept is fitted on each group and
    both models are then evaluated on every row of ``X``.

    Args:
        X: Two-dimensional covariate array.
        y0: Control outcomes (assumed 1-D, one per row of ``X`` -- TODO
            confirm).
        y1: Treated outcomes, same shape as ``y0``.
        params: Unused here; accepted so that all estimators share one
            signature.

    Returns:
        Treated-model predictions minus control-model predictions, each with
        an extra axis inserted at position 1.
    """
    _num_rows, _num_cols = X.shape  # unpacking requires a 2-D X
    row_norms = np.linalg.norm(X, axis=1)
    zeta = max(row_norms)

    y0 = np.expand_dims(y0, axis=1)
    y1 = np.expand_dims(y1, axis=1)
    num_units = y0.shape[0]

    # Ridge penalty: 100 * log(sample count) * (largest row norm)^2.
    ridge_penalty = 100 * np.log(num_units) * zeta * zeta

    # S is 1 and Sbar is -1
    assignment = np.random.choice([-1, 1], size=(num_units, 1), p=[0.5, 0.5])
    # Evaluates to y1 where the label is +1 and to y0 where it is -1.
    observed = ((y0 + y1) - assignment * (y0 - y1)) / 2

    def fit_group(label):
        # Fit a no-intercept ridge model on the rows carrying ``label``.
        member = assignment == label
        model = Ridge(alpha=ridge_penalty, fit_intercept=False)
        model.fit(X[np.squeeze(member), :], observed[member])
        return model

    control_model = fit_group(-1)
    treated_model = fit_group(1)

    y0hat = np.expand_dims(control_model.predict(X), axis=1)
    y1hat = np.expand_dims(treated_model.predict(X), axis=1)
    return y1hat - y0hat


def rand_vec_ite(X, y0, y1, params):
    """Estimate treatment effects from one regression on sign-flipped outcomes.

    Each row gets a random sign ``z`` (+1 or -1, equal probability). The
    regression target is ``y1`` where ``z == 1`` and ``-y0`` where
    ``z == -1``. A no-intercept ridge model is fitted on ``(2 * X, 2 * y)``
    and its predictions on ``2 * X`` are returned.

    Args:
        X: Covariate array.
        y0: Control outcomes (assumed 1-D, one per row of ``X`` -- TODO
            confirm).
        y1: Treated outcomes, same shape as ``y0``.
        params: Unused here; accepted so that all estimators share one
            signature.

    Returns:
        The fitted model's predictions for ``2 * X``.
    """
    y0 = np.expand_dims(y0, axis=1)
    y1 = np.expand_dims(y1, axis=1)

    # Regularisation is disabled. Alternatives that were left commented out:
    # params["lamda"], or np.log(n) * zeta**2 with zeta the largest row norm
    # of X. The row-norm computation was dropped because nothing reads it
    # while the penalty is fixed at 0.
    lamda = 0

    # Z+ is 1 and Z- is -1
    z = np.random.choice([-1, 1], size=(y0.shape[0], 1), p=[0.5, 0.5])
    y = ((y1 - y0) + np.multiply(z, y0 + y1)) / 2

    X_scaled = 2 * X  # built once, used for both fitting and prediction
    clf = Ridge(alpha=lamda, fit_intercept=False)
    clf.fit(X_scaled, 2 * y)
    return clf.predict(X_scaled)


def subsample_simple_lev_approach_ite(X, y0, y1, params):
    """Two-group ridge ITE estimate that uses only a random subsample of rows.

    Each row is labelled +1 with probability ``alpha / 2``, -1 with
    probability ``alpha / 2`` and 0 otherwise. Rows labelled +1 contribute
    ``y1``, rows labelled -1 contribute ``y0``, and rows labelled 0 enter
    neither fit. One no-intercept ridge model is fitted per non-zero label
    and both are evaluated on every row of ``X``.

    Args:
        X: Two-dimensional covariate array (it is indexed as ``X[mask, :]``).
        y0: Control outcomes (assumed 1-D, one per row of ``X`` -- TODO
            confirm).
        y1: Treated outcomes, same shape as ``y0``.
        params: Configuration mapping; ``params["alpha"]`` is the sampling
            fraction described above.

    Returns:
        Treated-model predictions minus control-model predictions, each with
        an extra axis inserted at position 1.

    Raises:
        KeyError: If ``params`` has no ``"alpha"`` entry.
    """
    alpha = params["alpha"]
    y0 = np.expand_dims(y0, axis=1)
    y1 = np.expand_dims(y1, axis=1)

    # Regularisation is disabled. The alternative that was left commented out
    # is 100 * np.log(n) * zeta**2 / alpha, with zeta the largest row norm of
    # X. The row-norm computation was dropped because nothing reads it while
    # the penalty is fixed at 0.
    lamda = 0

    # S is 1 and Sbar is -1; 0 marks rows left out of both fits.
    z = np.random.choice(
        [-1, 1, 0],
        size=(y0.shape[0], 1),
        p=[0.5 * alpha, 0.5 * alpha, 1 - alpha],
    )
    # Evaluates to y1 where z == 1 and y0 where z == -1; rows with z == 0 are
    # never selected below.
    y = ((y0 + y1) - np.multiply(z, y0 - y1)) / 2

    in_sbar = z == -1
    in_s = z == 1
    X1 = X[np.squeeze(in_sbar), :]
    yr1 = y[in_sbar]
    X2 = X[np.squeeze(in_s), :]
    yr2 = y[in_s]

    clf1 = Ridge(alpha=lamda, fit_intercept=False)
    clf1.fit(X1, yr1)
    clf2 = Ridge(alpha=lamda, fit_intercept=False)
    clf2.fit(X2, yr2)

    y0hat = np.expand_dims(clf1.predict(X), axis=1)
    y1hat = np.expand_dims(clf2.predict(X), axis=1)
    return y1hat - y0hat


def subsample_rand_vec_ite(X, y0, y1, params):
    """Sign-flipped-outcome ITE estimate fitted on a random subsample of rows.

    Each row gets a random sign ``z`` (+1 or -1, equal probability) and the
    target is ``y1`` where ``z == 1`` and ``-y0`` where ``z == -1``.
    Independently, each row is kept with probability ``alpha``. The kept
    rows of ``X`` and of the target are both divided by ``alpha`` and used to
    fit a no-intercept ridge model, which is then evaluated on ``2 * X``.

    Args:
        X: Two-dimensional covariate array (it is indexed as ``X[mask, :]``).
        y0: Control outcomes (assumed 1-D, one per row of ``X`` -- TODO
            confirm).
        y1: Treated outcomes, same shape as ``y0``.
        params: Configuration mapping; ``params["alpha"]`` is the probability
            of keeping a row.

    Returns:
        The fitted model's predictions for ``2 * X``.

    Raises:
        KeyError: If ``params`` has no ``"alpha"`` entry.
    """
    alpha = params["alpha"]
    y0 = np.expand_dims(y0, axis=1)
    y1 = np.expand_dims(y1, axis=1)

    # Regularisation is disabled. The alternative that was left commented out
    # is 100 * np.log(n) * zeta / alpha, with zeta the largest row norm of X.
    # The row-norm computation was dropped because nothing reads it while the
    # penalty is fixed at 0.
    lamda = 0

    num_units = y0.shape[0]
    # The two draws stay in this order so seeded runs give the same stream.
    # Z+ is 1 and Z- is -1
    z = np.random.choice([-1, 1], size=(num_units, 1), p=[0.5, 0.5])
    # S is 1
    z2 = np.random.choice([1, 0], size=(num_units, 1), p=[alpha, 1 - alpha])

    y = ((y1 - y0) + np.multiply(z, y0 + y1)) / 2

    kept = np.squeeze(z2 == 1)  # one row mask shared by X and y
    XS = X[kept, :] / alpha
    yS = y[kept] / alpha

    clf = Ridge(alpha=lamda, fit_intercept=False)
    clf.fit(XS, yS)
    return clf.predict(2 * X)


def baseline(X, y0, y1, params):
    """Fit a no-intercept ridge model directly on the true effect ``y1 - y0``.

    Both potential outcomes are used for every row, so this serves as a
    reference fit of the full effect vector on ``X``.

    Args:
        X: Covariate array.
        y0: Control outcomes (assumed 1-D, one per row of ``X`` -- TODO
            confirm).
        y1: Treated outcomes, same shape as ``y0``.
        params: Unused here; accepted so that all estimators share one
            signature.

    Returns:
        The fitted model's predictions for ``X``.
    """
    y0 = np.expand_dims(y0, axis=1)
    y1 = np.expand_dims(y1, axis=1)

    # Regularisation is disabled. The alternative that was left commented out
    # is 100 * np.log(n) * zeta**2, with zeta the largest row norm of X. The
    # row-norm computation was dropped because nothing reads it while the
    # penalty is fixed at 0.
    lamda = 0

    y = y1 - y0
    clf = Ridge(alpha=lamda, fit_intercept=False)
    clf.fit(X, y)
    return clf.predict(X)

